from pipelines.prompta.learner.java_utils.dfa import TTTLearnerDFA
from pipelines.prompta.oracle import BaseOracle
from pipelines.prompta.utils import save_dfa
from prompta.utils.java_libs import AcexAnalyzers
from .base_learner import BaseLearner


class TTTLearner(BaseLearner):
    """DFA learner driven by ``TTTLearnerDFA`` (the TTT algorithm).

    Queries are answered through the oracle. Hypothesis/counterexample
    rounds repeat until ``check_conjecture`` reports no counterexample.
    """

    ID = "TTT"

    def __init__(self, oracle: BaseOracle, exp_dir):
        """Set up the base learner and the underlying ``TTTLearnerDFA``.

        Args:
            oracle: Oracle exposing the input alphabet (``jalphabet``) and
                handing conjecture checks via ``check_conjecture``.
            exp_dir: Experiment directory, forwarded to ``BaseLearner``.
        """
        super().__init__(oracle, exp_dir)
        # NOTE(review): the alphabet is read from ``self.oracle`` while the
        # query oracle passed in is the raw ``oracle`` argument; these are
        # presumably the same object unless BaseLearner wraps it -- confirm.
        self.learner = TTTLearnerDFA(
            self.oracle.jalphabet, oracle, AcexAnalyzers.BINARY_SEARCH_BWD
        )

    def learn(self):
        """Run learning rounds until no counterexample is returned.

        Each round fetches the current hypothesis and passes it to
        ``self.oracle.check_conjecture``. A non-``None`` result is treated as
        a counterexample and fed to ``refineHypothesis``. Once
        ``check_conjecture`` returns ``None``, the final hypothesis is saved
        with ``save_dfa`` to ``self.get_dfa_save_path()`` and returned.

        Returns:
            The final hypothesis model produced by the learner.
        """
        self.learner.startLearning()
        while True:
            hypothesis = self.learner.getHypothesisModel()
            ce = self.oracle.check_conjecture(hypothesis, 'DefaultQuery')
            if ce is None:
                break
            self.learner.refineHypothesis(ce)
        # No refinement has happened since ``hypothesis`` was fetched, so it
        # is the final model; reuse it instead of asking the learner again,
        # which also guarantees the saved and returned models are identical.
        save_dfa(self.get_dfa_save_path(), hypothesis, self.oracle.jalphabet)
        return hypothesis


